package key_hash

import (
	"context"
	"errors"
	"gitee.com/vrv_media/go-micro-framework/registry"
	"gitee.com/vrv_media/go-micro-framework/registry/balance"
	"math/rand"
	"sync"
)

// 基于传入值的key的hash的算法：
// 相同的传入的key分配到相同的node上面
// 在使用这个算法的时候，需要在传入的context中传入需要的key
const (
	// ContextKey 这个值用来从context中获取到值
	ContextKey = "Key-hash-key"
)

var (
	NoContextKeyError = errors.New("context not has key ,key str = " + ContextKey)
)

type KeyHashBalancer struct {
	lock sync.RWMutex
	// 这里面存储已经使用过的node,用node的key作为主键，表示这些node已经使用过
	usedMap map[string]registry.Node
	// 这个存放被选择的key
	pickedMap map[string]registry.Node
}

func NewKeyHashBalancer() registry.Balancer {
	return &KeyHashBalancer{
		usedMap:   make(map[string]registry.Node),
		pickedMap: make(map[string]registry.Node),
	}
}

func (k *KeyHashBalancer) Pick(ctx context.Context, nodes []registry.Node) (selected registry.Node, err error) {
	if len(nodes) == 0 {
		return nil, balance.NoNodeError
	}
	key := ctx.Value(ContextKey).(string)
	if key == "" {
		return nil, NoContextKeyError
	}
	// 获取已经存在的值
	k.lock.RLock()
	if value, ok := k.pickedMap[key]; ok {
		k.lock.RUnlock()
		return value, nil
	}
	k.lock.RUnlock()
	// 还没有获取过的时候，优先找出没有使用过的node
	noUsedNodes := k.pickNoUsedNodes(nodes)
	// 如果都使用过了，随机选择一个
	if len(noUsedNodes) == 0 {
		selected = nodes[rand.Intn(len(nodes))]
	} else {
		selected = noUsedNodes[rand.Intn(len(noUsedNodes))]
	}
	k.lock.Lock()
	k.pickedMap[key] = selected
	k.usedMap[selected.Key()] = selected
	k.lock.Unlock()
	return selected, nil
}

func (k *KeyHashBalancer) pickNoUsedNodes(nodes []registry.Node) []registry.Node {
	k.lock.RLock()
	defer k.lock.RUnlock()
	noUsedNodes := make([]registry.Node, 0, len(nodes))
	for _, node := range nodes {
		if _, ok := k.usedMap[node.Key()]; !ok {
			noUsedNodes = append(noUsedNodes, node)
		}
	}
	return noUsedNodes
}
